0dd97230690f299f4aa86a4c510de88ad3d6b63d,src/main/java/ml/shifu/shifu/core/dtrain/nn/NNWorker.java,NNWorker,load,#GuaguaWritableAdapter#GuaguaWritableAdapter#WorkerContext#,53
Before Change
// if fixInitialInput = false, we only compare random value with baggingSampleRate to avoid parsing data.
// if fixInitialInput = true, we should use hashcode after parsing.
double baggingSampleRate = super.modelConfig.getBaggingSampleRate();
if(!super.modelConfig.isFixInitialInput() && Double.compare(Math.random(), baggingSampleRate) >= 0) {
// for negative tags, do sampleNegOnly logic
if(modelConfig.getTrain().getSampleNegOnly()) {
if(modelConfig.isRegression() && Double.compare(ideal[0] + 0d, 0d) == 0) {
return;
}
} else {
return;// normal sampling
}
}
// if fixInitialInput = true, we should use hashcode to sample.
long longBaggingSampleRate = Double.valueOf(baggingSampleRate * 100).longValue();
if(super.modelConfig.isFixInitialInput() && hashcode % 100 >= longBaggingSampleRate) {
// for negative tags, do sampleNegOnly logic
if(modelConfig.getTrain().getSampleNegOnly()) {
if(modelConfig.isRegression() && Double.compare(ideal[0] + 0d, 0d) == 0) {
return;
}
} else {
return;// normal sampling
}
}
// count stats after sampling
super.sampleCount += 1;
FloatMLDataPair pair = new BasicFloatMLDataPair(new BasicFloatMLData(inputs), new BasicFloatMLData(ideal));
if(modelConfig.isRegression() && isUpSampleEnabled() && Double.compare(ideal[0], 1d) == 0) {
// Double.compare(ideal[0], 1d) == 0 means a positive tag; add 1 to the sampled value so the sample count is never 0
pair.setSignificance(significance * (super.upSampleRng.sample() + 1));
} else {
pair.setSignificance(significance);
}
boolean isTesting = false;
if(workerContext.getAttachment() != null && workerContext.getAttachment() instanceof Boolean) {
isTesting = (Boolean) workerContext.getAttachment();
}
addDataPairToDataSet(hashcode, pair, isTesting);
}
/*
After Change
}
// if only negative records are sampled, do the sampling here, regardless of bagging or replacement.
if(modelConfig.getTrain().getSampleNegOnly() // sample negative enabled
&& (modelConfig.isRegression() || (modelConfig.isClassification() && modelConfig.getTrain()
.isOneVsAll())) // regression or onevsall
&& Double.compare(ideal[0] + 0.01d, 0d) == 0 // negative record
&& (!this.modelConfig.isFixInitialInput() && Double.compare(Math.random(),
this.modelConfig.getBaggingSampleRate()) >= 0)) {
return;
}
if(modelConfig.getTrain().getSampleNegOnly()// sample negative enabled
&& (modelConfig.isRegression() || (modelConfig.isClassification() && modelConfig.getTrain()
.isOneVsAll()))// regression or onevsall
&& (Double.compare(ideal[0] + 0.01d, 0d) == 0 // negative record
&& this.modelConfig.isFixInitialInput() && hashcode % 100 >= Double.valueOf(
this.modelConfig.getBaggingSampleRate() * 100).longValue())) {
return;
}
FloatMLDataPair pair = new BasicFloatMLDataPair(new BasicFloatMLData(inputs), new BasicFloatMLData(ideal));
// up-sampling logic: only add more weight to the record; the bagging sampling rate itself is not changed
if(modelConfig.isRegression() && isUpSampleEnabled() && Double.compare(ideal[0], 1d) == 0) {
// Double.compare(ideal[0], 1d) == 0 means a positive tag; add 1 to the sampled value so the sample count is never 0
pair.setSignificance(significance * (super.upSampleRng.sample() + 1));
} else {
pair.setSignificance(significance);
}
boolean isValidation = false;
if(workerContext.getAttachment() != null && workerContext.getAttachment() instanceof Boolean) {
isValidation = (Boolean) workerContext.getAttachment();
}
boolean isInTraining = addDataPairToDataSet(hashcode, pair, isValidation);
// do bagging sampling only for training data
if(isInTraining) {
float subsampleWeights = sampleWeights(pair.getIdealArray()[0]);
if(isPositive(pair.getIdealArray()[0])) {
this.positiveSelectedTrainCount += subsampleWeights * 1L;
} else {
this.negativeSelectedTrainCount += subsampleWeights * 1L;